from network.ISTANet import *
from network.ISTANetPlus import *
from network.DPDNN import *


def chooseModel(args):
    model_type = args.model_type
    layer_num = args.layer_num
    if model_type == 'ISTA_Net_plus':
        model = ISTANetplus(layer_num)
    if model_type == 'DPDNN':
        model = DPDNN(layer_num)
    else:
        assert False, "not support such model_type"
    return model